# scripts/step3_sr.py
import argparse
import os
import json
import math
import numpy as np
import pandas as pd
from time import perf_counter

from src.present_act.gates import ThetaLadder, KappaLadder, StructuralGates, CRA
from src.present_act.lints import Lints
from src.present_act.engine import PresentActEngine, RunManifest
from src.present_act.scenes import make_optics_scene, make_optics_roi, place_sources_from_s
from src.present_act.sr import SRBudget, GammaCalc
from scripts._util import ensure_out, write_md, load_cfg


def run_sr_for_alpha(scene, c, alpha_req, shots=800, seed=101, kappa_level=2):
    """
    Run one SR drift experiment for a requested drift and return its budgets.

    Boolean drift: per shot choose forward rail with prob p=alpha_req, else back rail.
    Dynamic Θ depth (with headroom) per scene so BFS reaches ROI midline.
    Per-accept budgets: Δτ=1, Δx=±1 (sign by drift), Δt = sqrt(1+(Δx/c)^2).

    Args:
        scene: optics scene; must expose ``roi_bbox`` as (x0, y0, x1, y1),
            plus ``W`` and ``H``.
        c: speed scale used in the Δt budget and passed to the run manifest;
            must be non-zero (it is a divisor).
        alpha_req: probability of picking the forward rail on each shot.
        shots: number of propose/accept rounds.
        seed: seeds both the run manifest and the rail-choice generator, so a
            given (seed, alpha_req) pair reproduces the same rail sequence.
        kappa_level: the single κ level handed to ``KappaLadder``.

    Returns:
        dict with keys ``alpha_req``, ``alpha_emp``, ``gamma_emp``, ``sum_dt``,
        ``sum_dtau``, ``sum_dx``, ``commits`` and ``theta_bins``.
        ``alpha_emp``/``gamma_emp`` are NaN when no shot was committed, because
        the estimates are undefined with empty budgets.
    """
    # Geometry: only the vertical ROI extent is needed (for the midline row).
    _, y0, _, y1 = scene.roi_bbox
    y_mid = (y0 + y1) // 2

    # Place far-apart rails so forward/back are cleanly separated.
    (xL, y_sL), (xR, y_sR) = place_sources_from_s(scene, s=int(0.8 * (scene.W // 2)), y_row=scene.H // 4)

    # --- Dynamic Theta: exact reach distance + headroom ---
    # NOTE(review): the distance is measured from the left rail only; presumably
    # both rails sit on the same y_row — confirm against place_sources_from_s.
    dist = abs(y_sL - y_mid)
    theta_bins = [max(1, dist + 5), dist + 7, dist + 9]

    theta = ThetaLadder(theta_bins)
    kappa = KappaLadder([kappa_level])
    man = RunManifest(theta, kappa, StructuralGates(), CRA(True), Lints(), seed=seed, c_units=c)
    screen = make_optics_roi(scene)
    eng = PresentActEngine(scene, man)

    # Use a dedicated, seeded generator rather than global NumPy RNG state, so
    # `seed` actually determines the rail sequence and other global-RNG users
    # cannot perturb it.
    rng = np.random.default_rng(seed)

    budget = SRBudget()
    commits = 0

    for _ in range(shots):
        forward = rng.random() < alpha_req  # Bernoulli(p=alpha_req)
        src = (xR, y_sR) if forward else (xL, y_sL)
        cands = eng.propose_candidates([src], screen)
        acc, _ = eng.accept(cands)
        if acc is None:
            continue

        x, _ = acc
        # Optional: mark at midline for visibility.
        # NOTE(review): indexes `screen` with a scene-space row; confirm that
        # make_optics_roi returns a scene-sized array. `screen` is also passed to
        # propose_candidates on every shot, so this mark may influence later
        # proposals — confirm that is intended.
        screen[y_mid, x] += 1

        # Per-accept budgets.
        d_tau = 1.0
        d_x = 1.0 if forward else -1.0
        d_t = math.sqrt(d_tau * d_tau + (d_x * d_x) / (c * c))
        budget.sum_dtau += d_tau
        budget.sum_dt += d_t
        budget.sum_dx += d_x
        commits += 1

    if commits:
        alpha_emp = GammaCalc.alpha(budget.sum_dx, c, budget.sum_dt)
        gamma_emp = GammaCalc.gamma(budget.sum_dtau, budget.sum_dt)
    else:
        # Nothing accepted: all budgets are zero and the ratios are undefined.
        alpha_emp = gamma_emp = float("nan")

    return {
        "alpha_req": alpha_req,
        "alpha_emp": alpha_emp,
        "gamma_emp": gamma_emp,
        "sum_dt": budget.sum_dt,
        "sum_dtau": budget.sum_dtau,
        "sum_dx": budget.sum_dx,
        "commits": commits,
        "theta_bins": theta_bins,
    }


def _load_calibrated_c(path, default=1.0):
    """Return the calibrated ``c`` stored under key "c" in the JSON file at *path*.

    Falls back to *default* (with a printed warning) when the file is missing or
    unreadable, is not a JSON object, has a non-numeric "c", or holds a value
    that is not a positive finite number (``c`` is later used as a divisor).
    """
    try:
        with open(path, "r") as f:
            c = float(json.load(f).get("c", default))
    except (OSError, ValueError, TypeError, AttributeError) as err:
        # ValueError covers json.JSONDecodeError and float("junk");
        # AttributeError covers a top-level JSON value that is not an object.
        print(f"[SR] could not read calibrated c from {path} ({err}); using c={default}", flush=True)
        return default
    if not math.isfinite(c) or c <= 0.0:
        print(f"[SR] calibrated c={c!r} in {path} is not a positive finite number; using c={default}",
              flush=True)
        return default
    return c


def _gamma_theory(alpha):
    """Return 1/sqrt(1 - alpha^2): inf when |alpha| >= 1, NaN when alpha is NaN."""
    if math.isnan(alpha):
        return math.nan
    if abs(alpha) >= 1.0:
        # Negative drift is possible (back-rail steps have Δx = -1), so the
        # guard must be symmetric to avoid complex results or ZeroDivisionError.
        return math.inf
    return 1.0 / math.sqrt(1.0 - alpha * alpha)


def main(cfg):
    """
    Sweep SR runs over every (seed, alpha) pair in *cfg* and write the results.

    Config keys read: ``common.seeds``, ``common.shots``, ``scene.w_inner_px``,
    and the optional ``sr.kappa_level`` (default 2) and ``sr.alphas``.
    Per-run rows go to ``sr_gamma.csv`` in the directory returned by
    ``ensure_out("out", "step3")``. The gamma RMSE summary goes to
    ``out/RESULTS_SR.md``.

    Raises:
        ValueError: if no seeds or no alphas are configured.
    """
    out3 = ensure_out("out", "step3")

    # Load calibrated c (falls back to 1.0 with a warning).
    c = _load_calibrated_c(os.path.join("out", "calibration_time_hinge.json"))

    # Config
    common = cfg["common"]
    sr_cfg = cfg.get("sr") or {}
    seeds = common["seeds"]
    shots = int(common["shots"])
    kappa_level = int(sr_cfg.get("kappa_level", 2))
    alphas = sr_cfg.get("alphas", [0.0, 0.25, 0.5, 0.75, 0.9])
    if not seeds or not alphas:
        raise ValueError("SR config needs at least one seed (common.seeds) and one alpha (sr.alphas)")

    # Use a mid-size container for SR
    L = 512
    scene = make_optics_scene(L, w=int(cfg["scene"]["w_inner_px"]))

    rows = []
    print(f"[SR] c={c:.6f}, L_out={L}, kappa={kappa_level}, shots={shots}", flush=True)
    for seed in seeds:
        for a in alphas:
            t0 = perf_counter()
            r = run_sr_for_alpha(scene, c, a, shots=shots, seed=seed, kappa_level=kappa_level)
            r.update({"seed": seed, "L_out": L, "c_used": c, "kappa_level": kappa_level})
            rows.append(r)
            print(f"[SR] seed={seed} a_req={a:.2f} commits={r['commits']}"
                  f" a_emp={r['alpha_emp']:.4f} g_emp={r['gamma_emp']:.4f}"
                  f" θ={r['theta_bins']} elapsed={perf_counter()-t0:.2f}s", flush=True)

    df = pd.DataFrame(rows)
    # Theory gamma from empirical alpha
    df["gamma_theory"] = df["alpha_emp"].apply(_gamma_theory)
    df.to_csv(os.path.join(out3, "sr_gamma.csv"), index=False)

    # RMSE. Series.mean() skips NaN by default, so runs with undefined
    # estimates do not enter the average.
    rmse = float(np.sqrt(((df["gamma_emp"] - df["gamma_theory"]) ** 2).mean()))
    write_md(os.path.join("out", "RESULTS_SR.md"),
             f"# STEP 3 — SR (boolean drift + dynamic Θ)\nRMSE_gamma={rmse:.4f}\n")
    print(f"SR step complete. RMSE_gamma= {rmse:.4f}")


if __name__ == "__main__":
    # CLI entry point: parse --config, load it, and run the SR sweep.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", default="configs/low_compute.yaml")
    main(load_cfg(cli.parse_args().config))
